An Overview of Generative Adversarial Networks

林嶔 (Lin, Chin)

Lesson 15

Introduction to the Principles (1)

[Figure F15_1]

– Mathematically, this amounts to building a prediction function whose goal is to map random noise to new objects:

[Figure F15_2]
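– For reference, the generator is a function G that maps random noise z to a synthetic sample, and the two networks play the standard minimax game introduced by Goodfellow et al. (2014):

$$\hat{x} = G(z), \qquad z \sim \mathcal{N}(0, I)$$

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$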

Introduction to the Principles (2)

[Figure F15_3]

– In practice, however, these two networks cannot be trained together in a single stage. Can you see why?

Implementing a Handwritten-Digit Generator (1)

library(mxnet)

my_iterator_func <- setRefClass("Custom_Iter",
                                fields = c("iter", "data.csv", "data.shape", "batch.size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, data.csv, data.shape, batch.size){
                                    csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
                                    .self$iter <- csv_iter
                                    .self
                                  },
                                  value = function(){
                                    val <- as.array(.self$iter$value()$data)
                                    val.x <- val[-1,]
                                    batch_size <- ncol(val.x)
                                    val.x <- val.x / 255 # Important        
                                    dim(val.x) <- c(28, 28, 1, batch_size)
                                    rand <- rnorm(batch_size * 10, mean = 0, sd = 1)
                                    rand <- array(rand, dim = c(1, 1, 10, batch_size))
                                    rand <- mx.nd.array(rand)
                                    val.x <- mx.nd.array(val.x)
                                    val.y.0 <- array(rep(0, batch_size), dim = c(1, 1, 1, batch_size))
                                    val.y.0 <- mx.nd.array(val.y.0)
                                    val.y.1 <- array(rep(1, batch_size), dim = c(1, 1, 1, batch_size))
                                    val.y.1 <- mx.nd.array(val.y.1)
                                    list(noise = rand, img = val.x, label.0 = val.y.0, label.1 = val.y.1)
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize=function(){
                                  }
                                )
)

my_iter <- my_iterator_func(iter = NULL,  data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 32)

Implementing a Handwritten-Digit Generator (2)

– First, define the Generator:

gen_data <- mx.symbol.Variable('data')

gen_deconv1 <- mx.symbol.Deconvolution(data = gen_data, kernel = c(4, 4), stride = c(2, 2), num_filter = 256, name = 'gen_deconv1')
gen_bn1 <- mx.symbol.BatchNorm(data = gen_deconv1, fix_gamma = TRUE, name = 'gen_bn1')
gen_relu1 <- mx.symbol.Activation(data = gen_bn1, act_type = "relu", name = 'gen_relu1')

gen_deconv2 <- mx.symbol.Deconvolution(data = gen_relu1, kernel = c(3, 3), stride = c(2, 2), pad = c(1, 1), num_filter = 128, name = 'gen_deconv2')
gen_bn2 <- mx.symbol.BatchNorm(data = gen_deconv2, fix_gamma = TRUE, name = 'gen_bn2')
gen_relu2 <- mx.symbol.Activation(data = gen_bn2, act_type = "relu", name = 'gen_relu2')

gen_deconv3 <- mx.symbol.Deconvolution(data = gen_relu2, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 64, name = 'gen_deconv3')
gen_bn3 <- mx.symbol.BatchNorm(data = gen_deconv3, fix_gamma = TRUE, name = 'gen_bn3')
gen_relu3 <- mx.symbol.Activation(data = gen_bn3, act_type = "relu", name = 'gen_relu3')

gen_deconv4 <- mx.symbol.Deconvolution(data = gen_relu3, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 1, name = 'gen_deconv4')
gen_pred <- mx.symbol.Activation(data = gen_deconv4, act_type = "sigmoid", name = 'gen_pred')
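– As a sanity check, each Deconvolution enlarges the feature map by out = (in - 1) × stride - 2 × pad + kernel, so the spatial size grows 1 → 4 → 7 → 14 → 28. A minimal sketch of this check, assuming mx.symbol.infer.shape from the R mxnet package:

gen_shapes <- mx.symbol.infer.shape(gen_pred, data = c(1, 1, 10, 32))
gen_shapes$out.shapes # expected output shape: c(28, 28, 1, 32)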

– Next, define the Discriminator:

dis_img <- mx.symbol.Variable('img')
dis_label <- mx.symbol.Variable('label')

dis_conv1 <- mx.symbol.Convolution(data = dis_img, kernel = c(3, 3), num_filter = 24, no.bias = TRUE, name = 'dis_conv1')
dis_bn1 <- mx.symbol.BatchNorm(data = dis_conv1, fix_gamma = TRUE, name = 'dis_bn1')
dis_relu1 <- mx.symbol.LeakyReLU(data = dis_bn1, act_type = "leaky", slope = 0.25, name = "dis_relu1")
dis_pool1 <- mx.symbol.Pooling(data = dis_relu1, pool_type = "avg", kernel = c(2, 2), stride = c(2, 2), name = 'dis_pool1')

dis_conv2 <- mx.symbol.Convolution(data = dis_pool1, kernel = c(3, 3), stride = c(2, 2), num_filter = 32, no.bias = TRUE, name = 'dis_conv2')
dis_bn2 <- mx.symbol.BatchNorm(data = dis_conv2, fix_gamma = TRUE, name = 'dis_bn2')
dis_relu2 <- mx.symbol.LeakyReLU(data = dis_bn2, act_type = "leaky", slope = 0.25, name = "dis_relu2")

dis_conv3 <- mx.symbol.Convolution(data = dis_relu2, kernel = c(3, 3), num_filter = 64, no.bias = TRUE, name = 'dis_conv3')
dis_bn3 <- mx.symbol.BatchNorm(data = dis_conv3, fix_gamma = TRUE, name = 'dis_bn3')
dis_relu3 <- mx.symbol.LeakyReLU(data = dis_bn3, act_type = "leaky", slope = 0.25, name = "dis_relu3")

dis_conv4 <- mx.symbol.Convolution(data = dis_relu3, kernel = c(4, 4), num_filter = 64, no.bias = TRUE, name = 'dis_conv4')
dis_bn4 <- mx.symbol.BatchNorm(data = dis_conv4, fix_gamma = TRUE, name = 'dis_bn4')
dis_relu4 <- mx.symbol.LeakyReLU(data = dis_bn4, act_type = "leaky", slope = 0.25, name = "dis_relu4")

dis_conv5 <- mx.symbol.Convolution(data = dis_relu4, kernel = c(1, 1), num_filter = 1, name = 'dis_conv5')
dis_pred <- mx.symbol.sigmoid(data = dis_conv5, name = 'dis_pred')
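– The convolution stack shrinks the 28 × 28 input down to a single prediction per image (28 → 26 → 13 → 6 → 4 → 1 → 1), matching the (1, 1, 1, batch) label shape. The same kind of check, again assuming mx.symbol.infer.shape:

dis_shapes <- mx.symbol.infer.shape(dis_pred, img = c(28, 28, 1, 32))
dis_shapes$out.shapes # expected output shape: c(1, 1, 1, 32)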

– Next we define the loss function; note that only the Discriminator has one:

eps <- 1e-8
ce_loss_pos <-  mx.symbol.broadcast_mul(mx.symbol.log(dis_pred + eps), dis_label)
ce_loss_neg <-  mx.symbol.broadcast_mul(mx.symbol.log(1 - dis_pred + eps), 1 - dis_label)
ce_loss_mean <- 0 - mx.symbol.mean(ce_loss_pos + ce_loss_neg)
ce_loss <- mx.symbol.MakeLoss(ce_loss_mean, name = 'ce_loss')
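– This symbol graph is simply the binary cross-entropy averaged over the batch, with eps guarding against log(0):

$$\text{CE} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log(\hat{y}_i + \epsilon) + (1 - y_i)\log(1 - \hat{y}_i + \epsilon)\right]$$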

Implementing a Handwritten-Digit Generator (3)

gen_optimizer <- mx.opt.create(name = "adam", learning.rate = 2e-4, beta1 = 0.5, beta2 = 0.999, epsilon = 1e-08, wd = 0)
dis_optimizer <- mx.opt.create(name = "adam", learning.rate = 2e-4, beta1 = 0.5, beta2 = 0.999, epsilon = 1e-08, wd = 0)
gen_executor <- mx.simple.bind(symbol = gen_pred,
                               data = c(1, 1, 10, 32),
                               ctx = mx.cpu(), grad.req = "write")

dis_executor <- mx.simple.bind(symbol = ce_loss,
                               img = c(28, 28, 1, 32), label = c(1, 1, 1, 32),
                               ctx = mx.cpu(), grad.req = "write")
# Initial parameters

mx.set.seed(0)

gen_arg <- mxnet:::mx.model.init.params(symbol = gen_pred,
                                        input.shape = list(data = c(1, 1, 10, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

dis_arg <- mxnet:::mx.model.init.params(symbol = ce_loss,
                                        input.shape = list(img = c(28, 28, 1, 32), label = c(1, 1, 1, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

# Update parameters

mx.exec.update.arg.arrays(gen_executor, gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(gen_executor, gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(dis_executor, dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(dis_executor, dis_arg$aux.params, match.name = TRUE)
gen_updater <- mx.opt.get.updater(optimizer = gen_optimizer, weights = gen_executor$ref.arg.arrays)
dis_updater <- mx.opt.get.updater(optimizer = dis_optimizer, weights = dis_executor$ref.arg.arrays)

Implementing a Handwritten-Digit Generator (4)

# Generate data

my_iter$reset()
my_iter$iter.next()
## [1] TRUE
my_values <- my_iter$value()

# Generator (forward)
    
mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = my_values[['noise']]), match.name = TRUE)
mx.exec.forward(gen_executor, is.train = TRUE)
gen_pred_output <- gen_executor$ref.outputs[[1]]

# Discriminator (fake)
    
mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, label = my_values[['label.0']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)

# Discriminator (real)
    
mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = my_values[['img']], label = my_values[['label.1']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)

# Generator (backward)
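# The Generator has no loss of its own: we feed the fake images through the
# Discriminator with the "real" label (label.1), backpropagate, and read the
# gradient with respect to the 'img' input. That gradient is then chained into
# the Generator via out_grads. The Discriminator's weight gradients are also
# computed here, but they are discarded (dis_updater is not called in this step).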

mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, label = my_values[['label.1']]), match.name = TRUE)
mx.exec.forward(dis_executor, is.train = TRUE)
mx.exec.backward(dis_executor)
img_grads <- dis_executor$ref.grad.arrays[['img']]
mx.exec.backward(gen_executor, out_grads = img_grads)
gen_update_args <- gen_updater(weight = gen_executor$ref.arg.arrays, grad = gen_executor$ref.grad.arrays)
mx.exec.update.arg.arrays(gen_executor, gen_update_args, skip.null = TRUE)
library(imager)

par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))

for (i in 1:9) {
  img <- as.array(gen_pred_output)[,,,i]
  plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
  rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
}

Implementing a Handwritten-Digit Generator (5)

set.seed(0)
n.epoch <- 20
logger <- list(gen_loss = NULL, dis_real_loss = NULL, dis_fake_loss = NULL)
for (j in 1:n.epoch) {
  
  current_batch <- 0
  my_iter$reset()
  
  while (my_iter$iter.next()) {
    
    my_values <- my_iter$value()
    
    # Generator (forward)
    
    mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = my_values[['noise']]), match.name = TRUE)
    mx.exec.forward(gen_executor, is.train = TRUE)
    gen_pred_output <- gen_executor$ref.outputs[[1]]
    
    # Discriminator (fake)
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, label = my_values[['label.0']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_fake_loss <- c(logger$dis_fake_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Discriminator (real)
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = my_values[['img']], label = my_values[['label.1']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_real_loss <- c(logger$dis_real_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Generator (backward)
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, label = my_values[['label.1']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    img_grads <- dis_executor$ref.grad.arrays[['img']]
    mx.exec.backward(gen_executor, out_grads = img_grads)
    gen_update_args <- gen_updater(weight = gen_executor$ref.arg.arrays, grad = gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(gen_executor, gen_update_args, skip.null = TRUE)
    
    logger$gen_loss <- c(logger$gen_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    if (current_batch %% 100 == 0) {
      
      # Show current images
      
      par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
      for (i in 1:9) {
        img <- as.array(gen_pred_output)[,,,i]
        plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
        rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
      }
      
      # Show loss
      
      message('Epoch [', j, '] Batch [', current_batch, '] Generator-loss = ', formatC(tail(logger$gen_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (real) = ', formatC(tail(logger$dis_real_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (fake) = ', formatC(tail(logger$dis_fake_loss, 1), digits = 5, format = 'f'))
      
    }
    
    current_batch <- current_batch + 1
    
  }
  
  pdf(paste0('result/epoch_', j, '.pdf'), height = 6, width = 6)
  par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
  for (i in 1:9) {
    img <- as.array(gen_pred_output)[,,,i]
    plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
    rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
  }
  dev.off()
  
  gen_model <- list()
  gen_model$symbol <- gen_pred
  gen_model$arg.params <- gen_executor$ref.arg.arrays[-1]
  gen_model$aux.params <- gen_executor$ref.aux.arrays
  class(gen_model) <- "MXFeedForwardModel"
  
  dis_model <- list()
  dis_model$symbol <- dis_pred
  dis_model$arg.params <- dis_executor$ref.arg.arrays[-1]
  dis_model$aux.params <- dis_executor$ref.aux.arrays
  class(dis_model) <- "MXFeedForwardModel"
  
  mx.model.save(model = gen_model, prefix = 'model/gen_v1', iteration = j)
  mx.model.save(model = dis_model, prefix = 'model/dis_v1', iteration = j)
  
}

[Figure F15_5]

Implementing a Handwritten-Digit Generator (6)

[Figure F15_4]

gen_model <- mx.model.load('model/gen_v1', 0)

set.seed(1)

noise_input <- array(rnorm(100), dim = c(1, 1, 10, 10))
pred_img <- predict(gen_model, noise_input)

par(mfrow = c(2, 5), mar = c(0.1, 0.1, 0.1, 0.1))

for (i in 1:10) {
  img <- pred_img[,,,i]
  plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
  rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
}

Implementing a Handwritten-Digit Generator (7)

library(magrittr)

range_logger <- logger %>% unlist %>% range

plot(logger$gen_loss, type = 'l', col = 'red', lwd = 0.5, ylim = range_logger, xlab = 'Batch', ylab = 'loss')
lines(1:length(logger$dis_real_loss), logger$dis_real_loss, col = 'blue', lwd = 0.5)
lines(1:length(logger$dis_fake_loss), logger$dis_fake_loss, col = 'darkgreen', lwd = 0.5)
legend('topright', c('Gen', 'Real', 'Fake'), col = c('red', 'blue', 'darkgreen'), lwd = 1)

Implementing a Handwritten-Digit Generator (8)

set.seed(1)

input.1 <- rnorm(10)

noise_input <- array(input.1, dim = c(1, 1, 10, 10))
for (i in 2:10) {
  noise_input[,,,i] <- noise_input[,,,i-1] + 0.1
}

pred_img <- predict(gen_model, noise_input)

par(mfrow = c(2, 5), mar = c(0.1, 0.1, 0.1, 0.1))

for (i in 1:10) {
  img <- pred_img[,,,i]
  plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
  rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
}

Exercise 1: Implement your first GAN and tune its hyperparameters

– You may also have noticed that when we trained this GAN, many of the hyperparameter settings did not use the usual values:

  1. We optimized with Adam, yet used a learning rate of 0.0002 and a beta1 of 0.5. And why can't we simply switch to SGD?

  2. We used ReLU for the nonlinearities in the Generator but LeakyReLU in the Discriminator. Could we unify them?

  3. The Generator ends with a sigmoid output. Pixel values do lie between 0 and 1, but won't this cause vanishing gradients? Can we lift this constraint?

  4. We trained the Discriminator on fake and real samples in separate passes. Doesn't that seem odd?

Exercise 1 Answer (1)

gen_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-3, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-08, wd = 0)
dis_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-3, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-08, wd = 0)

[Figure F15_6]

Exercise 1 Answer (2)

  1. The Generator's final output should be a bounded squashing function such as tanh or sigmoid

  2. Each mini-batch must contain only fake samples or only real samples

  3. Avoid sparse gradients (ReLU, max pooling); as a rule, use Leaky ReLU with a slope of 0.2 in the Discriminator, while ReLU is fine in the Generator

  4. Train with Adam, but adjust to a learning rate of 0.0002 and a beta1 of 0.5

– Two further label tricks are also worth trying (we will use both in the CGAN iterator later):

  1. Flip the training labels (fake = 1; real = 0)

  2. Train the Discriminator with noisy labels (fake = 0.9-1.0; real = 0.0-0.1)

Transfer Feature Learning with the Discriminator (1)

– Since an autoencoder's encoder can be reused for transfer feature learning, we can naturally also try using the Discriminator for the same purpose.

my_iterator_func2 <- setRefClass("Custom_Iter2",
                                fields = c("iter", "data.csv", "data.shape", "batch.size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, data.csv, data.shape, batch.size){
                                    csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
                                    .self$iter <- csv_iter
                                    .self
                                  },
                                  value = function(){
                                    val <- as.array(.self$iter$value()$data)
                                    val.x <- val[-1,]
                                    dim(val.x) <- c(28, 28, 1, ncol(val.x))
                                    val.x <- val.x/255
                                    val.x <- mx.nd.array(val.x)
                                    val.y <- t(model.matrix(~ -1 + factor(val[1,], levels = 0:9)))
                                    val.y <- array(val.y, dim = c(10, dim(val.x)[4]))
                                    val.y <- mx.nd.array(val.y)
                                    list(data=val.x, label=val.y)
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize=function(){
                                  }
                                )
)

my_iter2 <- my_iterator_func2(iter = NULL,  data.csv = 'data/sub_train_data.csv', data.shape = 785, batch.size = 20)

Transfer Feature Learning with the Discriminator (2)

data <- mx.symbol.Variable('data')

dis_conv1 <- mx.symbol.Convolution(data = data, kernel = c(3, 3), num_filter = 24, no.bias = TRUE, name = 'dis_conv1')
dis_bn1 <- mx.symbol.BatchNorm(data = dis_conv1, fix_gamma = TRUE, name = 'dis_bn1')
dis_relu1 <- mx.symbol.LeakyReLU(data = dis_bn1, act_type = "leaky", slope = 0.25, name = "dis_relu1")
dis_pool1 <- mx.symbol.Pooling(data = dis_relu1, pool_type = "avg", kernel = c(2, 2), stride = c(2, 2), name = 'dis_pool1')

dis_conv2 <- mx.symbol.Convolution(data = dis_pool1, kernel = c(3, 3), stride = c(2, 2), num_filter = 32, no.bias = TRUE, name = 'dis_conv2')
dis_bn2 <- mx.symbol.BatchNorm(data = dis_conv2, fix_gamma = TRUE, name = 'dis_bn2')
dis_relu2 <- mx.symbol.LeakyReLU(data = dis_bn2, act_type = "leaky", slope = 0.25, name = "dis_relu2")

dis_conv3 <- mx.symbol.Convolution(data = dis_relu2, kernel = c(3, 3), num_filter = 64, no.bias = TRUE, name = 'dis_conv3')
dis_bn3 <- mx.symbol.BatchNorm(data = dis_conv3, fix_gamma = TRUE, name = 'dis_bn3')
dis_relu3 <- mx.symbol.LeakyReLU(data = dis_bn3, act_type = "leaky", slope = 0.25, name = "dis_relu3")

dis_conv4 <- mx.symbol.Convolution(data = dis_relu3, kernel = c(4, 4), num_filter = 64, no.bias = TRUE, name = 'dis_conv4')
dis_bn4 <- mx.symbol.BatchNorm(data = dis_conv4, fix_gamma = TRUE, name = 'dis_bn4')
dis_relu4 <- mx.symbol.LeakyReLU(data = dis_bn4, act_type = "leaky", slope = 0.25, name = "dis_relu4")

fc1 <- mx.symbol.FullyConnected(data = dis_relu4, num.hidden = 10, name = 'fc1')
softmax <- mx.symbol.softmax(data = fc1, axis = 1, name = 'softmax')

label <- mx.symbol.Variable(name = 'label')

eps <- 1e-8
m_log <- 0 - mx.symbol.mean(mx.symbol.broadcast_mul(mx.symbol.log(softmax + eps), label))
m_logloss <- mx.symbol.MakeLoss(m_log, name = 'm_logloss')
my_optimizer <- mx.opt.create(name = "adam", learning.rate = 0.001, beta1 = 0.9, beta2 = 0.999, wd = 1e-4)
my.eval.metric.loss <- mx.metric.custom(
  name = "mlog-loss", 
  function(real, pred) {
    # the MakeLoss output is already the scalar loss, so just report it
    return(as.array(pred))
  }
)

mx.set.seed(0)

model.1 <- mx.model.FeedForward.create(symbol = m_logloss, X = my_iter2, optimizer = my_optimizer,
                                       eval.metric = my.eval.metric.loss,
                                       array.batch.size = 20, ctx = mx.cpu(), num.round = 100)

Transfer Feature Learning with the Discriminator (3)

library(data.table)

Test.DAT = fread("data/test_data.csv", data.table = FALSE)

Test.X = t(Test.DAT[,-1])
dim(Test.X) = c(28, 28, 1, ncol(Test.X))
Test.X = Test.X/255
Test.Y = Test.DAT[,1]
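# Swap the training (loss) symbol for the softmax output so that predict()
# returns class probabilities instead of the scalar loss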
model.1$symbol <- softmax

predict_Y <- predict(model.1, Test.X)
confusion_table <- table(max.col(t(predict_Y)), Test.Y)
cat("Testing accuracy rate =", sum(diag(confusion_table))/sum(confusion_table))
## Testing accuracy rate = 0.928631
print(confusion_table)
##     Test.Y
##         0    1    2    3    4    5    6    7    8    9
##   1  1619    0    8    7    7    8   27    6   12    7
##   2     0 1823    8    3   14    8    5   16   22    7
##   3     2    8 1559   17    3    4    6   26   15    3
##   4     2    2   31 1660    0   99    1   24   24   15
##   5     1    3    3    0 1433    1   13   24    3   77
##   6     2    0    3   18    0 1380   67    0   20    6
##   7     8    1    3    0   10   10 1534    1   10    0
##   8     1    6   21    9    4   11    0 1567    3   39
##   9    21    7   18   19   17   17    6    8 1544    6
##   10    7    1    2    9  118   13    2   81   22 1482

Transfer Feature Learning with the Discriminator (4)

– If you want a readily trained Discriminator, you can download dis_v1-0000.params and dis_v1-symbol.json.

dis_model <- mx.model.load('model/dis_v1', 0)

mx.set.seed(0)
new_arg <- mxnet:::mx.model.init.params(symbol = m_logloss,
                                        input.shape = list(data = c(28, 28, 1, 7), label = c(10, 7)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

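# Copy the 12 backbone parameters (dis_conv1 to dis_conv4 weights plus the four
# BatchNorm gamma/beta pairs) from the trained GAN Discriminator; only the new
# fc1 layer keeps its random initialization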
for (k in 1:12) {
  new_arg$arg.params[[k]] <- dis_model$arg.params[[k]]
}

model.2 <- mx.model.FeedForward.create(symbol = m_logloss, X = my_iter2, optimizer = my_optimizer,
                                       eval.metric = my.eval.metric.loss,
                                       arg.params = new_arg$arg.params,
                                       array.batch.size = 20, ctx = mx.cpu(), num.round = 100)
model.2$symbol <- softmax

predict_Y <- predict(model.2, Test.X)
confusion_table <- table(max.col(t(predict_Y)), Test.Y)
cat("Testing accuracy rate =", sum(diag(confusion_table))/sum(confusion_table))
## Testing accuracy rate = 0.9335119
print(confusion_table)
##     Test.Y
##         0    1    2    3    4    5    6    7    8    9
##   1  1635    0    5    3    5   19   16   14   11   10
##   2     0 1820   10    2   10    7    7   19   13   10
##   3     2    9 1565   26    3    4    5   18   23    3
##   4     2    0   23 1644    0   41    0   22   27   26
##   5     1    4    6    0 1447    1   22   20    3   84
##   6     1    0    1   23    0 1436   50    3   28    9
##   7     5    2    5    1   12   17 1544    1    7    0
##   8     3    4   18    7    7    5    0 1585    0   18
##   9    12   12   21   25   19   14    8   10 1539   14
##   10    2    0    2   11  103    7    9   61   24 1468

Conditional Generative Adversarial Networks (1)

– Achieving this is not particularly difficult: we simply add a condition label to the network's input. A network of this kind is called a Conditional Generative Adversarial Net (CGAN).

gen_data <- mx.symbol.Variable('data')
gen_digit <- mx.symbol.Variable('digit')

gen_concat <- mx.symbol.concat(data = list(gen_data, gen_digit), num.args = 2, dim = 1, name = "gen_concat")
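# gen_data (1, 1, 10, N) and gen_digit (1, 1, 10, N) are concatenated along the
# channel axis, giving a (1, 1, 20, N) conditional input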

gen_deconv1 <- mx.symbol.Deconvolution(data = gen_concat, kernel = c(4, 4), stride = c(2, 2), num_filter = 256, name = 'gen_deconv1')
gen_bn1 <- mx.symbol.BatchNorm(data = gen_deconv1, fix_gamma = TRUE, name = 'gen_bn1')
gen_relu1 <- mx.symbol.Activation(data = gen_bn1, act_type = "relu", name = 'gen_relu1')

gen_deconv2 <- mx.symbol.Deconvolution(data = gen_relu1, kernel = c(3, 3), stride = c(2, 2), pad = c(1, 1), num_filter = 128, name = 'gen_deconv2')
gen_bn2 <- mx.symbol.BatchNorm(data = gen_deconv2, fix_gamma = TRUE, name = 'gen_bn2')
gen_relu2 <- mx.symbol.Activation(data = gen_bn2, act_type = "relu", name = 'gen_relu2')

gen_deconv3 <- mx.symbol.Deconvolution(data = gen_relu2, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 64, name = 'gen_deconv3')
gen_bn3 <- mx.symbol.BatchNorm(data = gen_deconv3, fix_gamma = TRUE, name = 'gen_bn3')
gen_relu3 <- mx.symbol.Activation(data = gen_bn3, act_type = "relu", name = 'gen_relu3')

gen_deconv4 <- mx.symbol.Deconvolution(data = gen_relu3, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 1, name = 'gen_deconv4')
gen_pred <- mx.symbol.Activation(data = gen_deconv4, act_type = "sigmoid", name = 'gen_pred')

Conditional Generative Adversarial Networks (2)

dis_img <- mx.symbol.Variable('img')
dis_digit <- mx.symbol.Variable("digit")
dis_label <- mx.symbol.Variable('label')

dis_concat <- mx.symbol.broadcast_mul(lhs = dis_img, rhs = dis_digit, name = 'dis_concat')
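# Broadcasting the (28, 28, 1, N) image against the (1, 1, 10, N) one-hot digit
# yields a (28, 28, 10, N) tensor: only the channel of the labelled digit
# carries the image; the other nine channels are zero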

dis_conv1 <- mx.symbol.Convolution(data = dis_concat, kernel = c(3, 3), num_filter = 24, no.bias = TRUE, name = 'dis_conv1')
dis_bn1 <- mx.symbol.BatchNorm(data = dis_conv1, fix_gamma = TRUE, name = 'dis_bn1')
dis_relu1 <- mx.symbol.LeakyReLU(data = dis_bn1, act_type = "leaky", slope = 0.2, name = "dis_relu1")
dis_pool1 <- mx.symbol.Pooling(data = dis_relu1, pool_type = "avg", kernel = c(2, 2), stride = c(2, 2), name = 'dis_pool1')

dis_conv2 <- mx.symbol.Convolution(data = dis_pool1, kernel = c(3, 3), stride = c(2, 2), num_filter = 32, no.bias = TRUE, name = 'dis_conv2')
dis_bn2 <- mx.symbol.BatchNorm(data = dis_conv2, fix_gamma = TRUE, name = 'dis_bn2')
dis_relu2 <- mx.symbol.LeakyReLU(data = dis_bn2, act_type = "leaky", slope = 0.2, name = "dis_relu2")

dis_conv3 <- mx.symbol.Convolution(data = dis_relu2, kernel = c(3, 3), num_filter = 64, no.bias = TRUE, name = 'dis_conv3')
dis_bn3 <- mx.symbol.BatchNorm(data = dis_conv3, fix_gamma = TRUE, name = 'dis_bn3')
dis_relu3 <- mx.symbol.LeakyReLU(data = dis_bn3, act_type = "leaky", slope = 0.2, name = "dis_relu3")

dis_conv4 <- mx.symbol.Convolution(data = dis_relu3, kernel = c(4, 4), num_filter = 64, no.bias = TRUE, name = 'dis_conv4')
dis_bn4 <- mx.symbol.BatchNorm(data = dis_conv4, fix_gamma = TRUE, name = 'dis_bn4')
dis_relu4 <- mx.symbol.LeakyReLU(data = dis_bn4, act_type = "leaky", slope = 0.2, name = "dis_relu4")

dis_conv5 <- mx.symbol.Convolution(data = dis_relu4, kernel = c(1, 1), num_filter = 1, name = 'dis_conv5')
dis_pred <- mx.symbol.sigmoid(data = dis_conv5, name = 'dis_pred')

– Next we define the loss function:

eps <- 1e-8
ce_loss_pos <-  mx.symbol.broadcast_mul(mx.symbol.log(dis_pred + eps), dis_label)
ce_loss_neg <-  mx.symbol.broadcast_mul(mx.symbol.log(1 - dis_pred + eps), 1 - dis_label)
ce_loss_mean <- 0 - mx.symbol.mean(ce_loss_pos + ce_loss_neg)
ce_loss <- mx.symbol.MakeLoss(ce_loss_mean, name = 'ce_loss')

Conditional Generative Adversarial Networks (3)

my_iterator_func <- setRefClass("Custom_Iter",
                                fields = c("iter", "data.csv", "data.shape", "batch.size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, data.csv, data.shape, batch.size){
                                    csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
                                    .self$iter <- csv_iter
                                    .self
                                  },
                                  value = function(){
                                    val <- as.array(.self$iter$value()$data)
                                    val.x <- val[-1,]
                                    batch_size <- ncol(val.x)
                                    val.x <- val.x / 255 # Important        
                                    dim(val.x) <- c(28, 28, 1, batch_size)
                                    val.x <- mx.nd.array(val.x)
                                    
                                    digit.real <- mx.nd.array(val[1,])
                                    digit.real <- mx.nd.one.hot(indices = digit.real, depth = 10)
                                    digit.real <- mx.nd.reshape(data = digit.real, shape = c(1, 1, -1, batch_size))
                                      
                                    digit.fake <- mx.nd.array(sample(0:9, size = batch_size, replace = TRUE))
                                    digit.fake <- mx.nd.one.hot(indices = digit.fake, depth = 10)
                                    digit.fake <- mx.nd.reshape(data = digit.fake, shape = c(1, 1, -1, batch_size))

                                    rand <- rnorm(batch_size * 10, mean = 0, sd = 1)
                                    rand <- array(rand, dim = c(1, 1, 10, batch_size))
                                    rand <- mx.nd.array(rand)
                                    
                                    # Flipped, noisy labels (the Exercise 1 tricks): real ~ 0-0.1,
                                    # fake ~ 0.9-1.0, drawing one value per sample in the batch
                                    label.real <- array(runif(batch_size, 0, 0.1), dim = c(1, 1, 1, batch_size))
                                    label.real <- mx.nd.array(label.real)
                                    label.fake <- array(runif(batch_size, 0.9, 1), dim = c(1, 1, 1, batch_size))
                                    label.fake <- mx.nd.array(label.fake)
                                    label.gen <- array(rep(0, batch_size), dim = c(1, 1, 1, batch_size))
                                    label.gen <- mx.nd.array(label.gen)
                                    
                                    list(noise = rand, img = val.x, digit.fake = digit.fake, digit.real = digit.real, label.fake = label.fake, label.real = label.real, label.gen = label.gen)
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize=function(){
                                  }
                                )
)

my_iter <- my_iterator_func(iter = NULL,  data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 32)

Conditional Generative Adversarial Networks (4)

gen_optimizer <- mx.opt.create(name = "adam", learning.rate = 2e-4, beta1 = 0.5, beta2 = 0.999, epsilon = 1e-08, wd = 0)
dis_optimizer <- mx.opt.create(name = "adam", learning.rate = 2e-4, beta1 = 0.5, beta2 = 0.999, epsilon = 1e-08, wd = 0)
gen_executor <- mx.simple.bind(symbol = gen_pred,
                               data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32),
                               ctx = mx.cpu(), grad.req = "write")

dis_executor <- mx.simple.bind(symbol = ce_loss,
                               img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32),
                               ctx = mx.cpu(), grad.req = "write")
# Initial parameters

mx.set.seed(0)

gen_arg <- mxnet:::mx.model.init.params(symbol = gen_pred,
                                        input.shape = list(data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

dis_arg <- mxnet:::mx.model.init.params(symbol = ce_loss,
                                        input.shape = list(img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

# Update parameters

mx.exec.update.arg.arrays(gen_executor, gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(gen_executor, gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(dis_executor, dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(dis_executor, dis_arg$aux.params, match.name = TRUE)
gen_updater <- mx.opt.get.updater(optimizer = gen_optimizer, weights = gen_executor$ref.arg.arrays)
dis_updater <- mx.opt.get.updater(optimizer = dis_optimizer, weights = dis_executor$ref.arg.arrays)

Conditional Generative Adversarial Networks (5)

set.seed(0)
n.epoch <- 20
logger <- list(gen_loss = NULL, dis_real_loss = NULL, dis_fake_loss = NULL)
for (j in 1:n.epoch) {
  
  current_batch <- 0
  my_iter$reset()
  
  while (my_iter$iter.next()) {
    
    my_values <- my_iter$value()
    
    # Generator (forward)
    
    mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = my_values[['noise']], digit = my_values[['digit.fake']]), match.name = TRUE)
    mx.exec.forward(gen_executor, is.train = TRUE)
    gen_pred_output <- gen_executor$ref.outputs[[1]]
    
    # Discriminator (fake)
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, digit = my_values[['digit.fake']], label = my_values[['label.fake']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_fake_loss <- c(logger$dis_fake_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Discriminator (real)
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = my_values[['img']], digit = my_values[['digit.real']], label = my_values[['label.real']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_real_loss <- c(logger$dis_real_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Generator (backward)
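    # Because the labels are flipped (real ~ 0-0.1, fake ~ 0.9-1.0), the
    # Generator is trained against label.gen = 0, i.e. it pushes the
    # Discriminator to classify its fakes as real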
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, digit = my_values[['digit.fake']], label = my_values[['label.gen']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    img_grads <- dis_executor$ref.grad.arrays[['img']]
    mx.exec.backward(gen_executor, out_grads = img_grads)
    gen_update_args <- gen_updater(weight = gen_executor$ref.arg.arrays, grad = gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(gen_executor, gen_update_args, skip.null = TRUE)
    
    logger$gen_loss <- c(logger$gen_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    if (current_batch %% 100 == 0) {
      
      # Show current images
      
      par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
      for (i in 1:9) {
        img <- as.array(gen_pred_output)[,,,i]
        plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
        rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
      }
      
      # Show loss
      
      message('Epoch [', j, '] Batch [', current_batch, '] Generator-loss = ', formatC(tail(logger$gen_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (real) = ', formatC(tail(logger$dis_real_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (fake) = ', formatC(tail(logger$dis_fake_loss, 1), digits = 5, format = 'f'))
      
    }
    
    current_batch <- current_batch + 1
    
  }
  
  pdf(paste0('result/epoch_', j, '.pdf'), height = 6, width = 6)
  par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
  for (i in 1:9) {
    img <- as.array(gen_pred_output)[,,,i]
    plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
    rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
  }
  dev.off()
  
  gen_model <- list()
  gen_model$symbol <- gen_pred
  gen_model$arg.params <- gen_executor$ref.arg.arrays[-c(1:2)]
  gen_model$aux.params <- gen_executor$ref.aux.arrays
  class(gen_model) <- "MXFeedForwardModel"
  
  dis_model <- list()
  dis_model$symbol <- dis_pred
  dis_model$arg.params <- dis_executor$ref.arg.arrays[-c(1:2)]
  dis_model$aux.params <- dis_executor$ref.aux.arrays
  class(dis_model) <- "MXFeedForwardModel"
  
  mx.model.save(model = gen_model, prefix = 'model/cgen_v1', iteration = j)
  mx.model.save(model = dis_model, prefix = 'model/cdis_v1', iteration = j)
  
}

Exercise 2: Use the CGAN to generate specified digits

– As training progresses, you should gradually see the digits come closer and closer to the real thing!

[Figure F15_7]

– If training is too slow on your machine, you can first download cgen_v1-0000.params and cgen_v1-symbol.json to obtain the Generator's parameters.

gen_model <- mx.model.load('model/cgen_v1', 0)

Exercise 2 Answer

my_predict <- function (model, digits = 0:9) {
  
  batch_size <- length(digits)
  
  gen_executor <- mx.simple.bind(symbol = model$symbol,
                                 data = c(1, 1, 10, batch_size), digit = c(1, 1, 10, batch_size),
                                 ctx = mx.cpu())
  
  mx.exec.update.arg.arrays(gen_executor, model$arg.params, match.name = TRUE)
  mx.exec.update.aux.arrays(gen_executor, model$aux.params, match.name = TRUE)
  
  noise_array <- rnorm(batch_size * 10, mean = 0, sd = 1)
  noise_array <- array(noise_array, dim = c(1, 1, 10, batch_size))
  noise_array <- mx.nd.array(noise_array)
  
  digit_array <- mx.nd.array(digits)
  digit_array <- mx.nd.one.hot(indices = digit_array, depth = 10)
  digit_array <- mx.nd.reshape(data = digit_array, shape = c(1, 1, -1, batch_size))
  
  mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = noise_array, digit = digit_array), match.name = TRUE)
  mx.exec.forward(gen_executor, is.train = FALSE)
  gen_pred_output <- gen_executor$ref.outputs[[1]]
  
  return(as.array(gen_pred_output))
  
}
pred_img <- my_predict(model = gen_model, digits = 0:9)

par(mfrow = c(2, 5), mar = c(0.1, 0.1, 0.1, 0.1))

for (i in 1:10) {
  img <- pred_img[,,,i]
  plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
  rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
}

Using Fake Data to Assist Model Training (1)

– Let's give it a try. Note, however, that the generative model we just trained actually used 25,200 labeled samples, so we first rebuild the baseline model. This is the iterator part:

my_iter2 <- my_iterator_func2(iter = NULL,  data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 20)
data <- mx.symbol.Variable('data')

dis_conv1 <- mx.symbol.Convolution(data = data, kernel = c(3, 3), num_filter = 24, no.bias = TRUE, name = 'dis_conv1')
dis_bn1 <- mx.symbol.BatchNorm(data = dis_conv1, fix_gamma = TRUE, name = 'dis_bn1')
dis_relu1 <- mx.symbol.LeakyReLU(data = dis_bn1, act_type = "leaky", slope = 0.25, name = "dis_relu1")
dis_pool1 <- mx.symbol.Pooling(data = dis_relu1, pool_type = "avg", kernel = c(2, 2), stride = c(2, 2), name = 'dis_pool1')

dis_conv2 <- mx.symbol.Convolution(data = dis_pool1, kernel = c(3, 3), stride = c(2, 2), num_filter = 32, no.bias = TRUE, name = 'dis_conv2')
dis_bn2 <- mx.symbol.BatchNorm(data = dis_conv2, fix_gamma = TRUE, name = 'dis_bn2')
dis_relu2 <- mx.symbol.LeakyReLU(data = dis_bn2, act_type = "leaky", slope = 0.25, name = "dis_relu2")

dis_conv3 <- mx.symbol.Convolution(data = dis_relu2, kernel = c(3, 3), num_filter = 64, no.bias = TRUE, name = 'dis_conv3')
dis_bn3 <- mx.symbol.BatchNorm(data = dis_conv3, fix_gamma = TRUE, name = 'dis_bn3')
dis_relu3 <- mx.symbol.LeakyReLU(data = dis_bn3, act_type = "leaky", slope = 0.25, name = "dis_relu3")

dis_conv4 <- mx.symbol.Convolution(data = dis_relu3, kernel = c(4, 4), num_filter = 64, no.bias = TRUE, name = 'dis_conv4')
dis_bn4 <- mx.symbol.BatchNorm(data = dis_conv4, fix_gamma = TRUE, name = 'dis_bn4')
dis_relu4 <- mx.symbol.LeakyReLU(data = dis_bn4, act_type = "leaky", slope = 0.25, name = "dis_relu4")

fc1 <- mx.symbol.FullyConnected(data = dis_relu4, num.hidden = 10, name = 'fc1')
softmax <- mx.symbol.softmax(data = fc1, axis = 1, name = 'softmax')

label <- mx.symbol.Variable(name = 'label')

eps <- 1e-8
m_log <- 0 - mx.symbol.mean(mx.symbol.broadcast_mul(mx.symbol.log(softmax + eps), label))
m_logloss <- mx.symbol.MakeLoss(m_log, name = 'm_logloss')
my_optimizer <- mx.opt.create(name = "adam", learning.rate = 0.001, beta1 = 0.9, beta2 = 0.999, wd = 1e-4)

Using Fake Data to Assist Model Training (2)

dis_model <- mx.model.load('model/dis_v1', 0)

mx.set.seed(0)
new_arg <- mxnet:::mx.model.init.params(symbol = m_logloss,
                                        input.shape = list(data = c(28, 28, 1, 7), label = c(10, 7)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

for (k in 1:12) {
  new_arg$arg.params[[k]] <- dis_model$arg.params[[k]]
}

model.1 <- mx.model.FeedForward.create(symbol = m_logloss, X = my_iter2, optimizer = my_optimizer,
                                       eval.metric = my.eval.metric.loss,
                                       arg.params = new_arg$arg.params,
                                       array.batch.size = 20, ctx = mx.cpu(), num.round = 20)
model.1$symbol <- softmax

predict_Y <- predict(model.1, Test.X)
confusion_table <- table(max.col(t(predict_Y)), Test.Y)
cat("Testing accuracy rate =", sum(diag(confusion_table))/sum(confusion_table))
## Testing accuracy rate = 0.9856548
print(confusion_table)
##     Test.Y
##         0    1    2    3    4    5    6    7    8    9
##   1  1652    0    2    0    4    2    8    1    1    7
##   2     0 1839    4    2    1    0    0    8    2    1
##   3     0    4 1641    7    1    1    2   15    5    0
##   4     0    1    1 1713    0    8    1    4    3    6
##   5     0    3    1    0 1581    0    1    7    1   10
##   6     1    0    0   10    0 1530    1    0    6    5
##   7     2    0    0    1    2    4 1645    0    6    0
##   8     0    0    2    3    3    0    0 1708    0   10
##   9     6    3    5    4    2    4    3    4 1648    1
##   10    2    1    0    2   12    2    0    6    3 1602

Using Fake Data to Assist Model Training (3)

gen_model <- mx.model.load('model/cgen_v1', 0)

my_predict <- function (model, digits = 0:9) {
  
  batch_size <- length(digits)
  
  gen_executor <- mx.simple.bind(symbol = model$symbol,
                                 data = c(1, 1, 10, batch_size), digit = c(1, 1, 10, batch_size),
                                 ctx = mx.cpu())
  
  mx.exec.update.arg.arrays(gen_executor, model$arg.params, match.name = TRUE)
  mx.exec.update.aux.arrays(gen_executor, model$aux.params, match.name = TRUE)
  
  noise_array <- rnorm(batch_size * 10, mean = 0, sd = 1)
  noise_array <- array(noise_array, dim = c(1, 1, 10, batch_size))
  noise_array <- mx.nd.array(noise_array)
  
  digit_array <- mx.nd.array(digits)
  digit_array <- mx.nd.one.hot(indices = digit_array, depth = 10)
  digit_array <- mx.nd.reshape(data = digit_array, shape = c(1, 1, -1, batch_size))
  
  mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = noise_array, digit = digit_array), match.name = TRUE)
  mx.exec.forward(gen_executor, is.train = FALSE)
  gen_pred_output <- gen_executor$ref.outputs[[1]]
  
  return(as.array(gen_pred_output))
  
}

my_iterator_func3 <- setRefClass("Custom_Iter3",
                                fields = c("iter", "data.csv", "data.shape", "batch.size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, data.csv, data.shape, batch.size){
                                    csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
                                    .self$iter <- csv_iter
                                    .self
                                  },
                                  value = function(){
                                    val <- as.array(.self$iter$value()$data)
                                    val.y <- val[1,]
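                                    # With probability 1/2, replace the whole batch with CGAN-generated
                                    # fakes conditioned on the same digit labels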
                                    if (sample(0:1, 1) == 1) {
                                      val.x <- my_predict(model = gen_model, digits = val.y)
                                    } else {
                                      val.x <- val[-1,]
                                      dim(val.x) <- c(28, 28, 1, ncol(val.x))
                                      val.x <- val.x/255
                                    }
                                    val.x <- mx.nd.array(val.x)
                                    val.y <- t(model.matrix(~ -1 + factor(val.y, levels = 0:9)))
                                    val.y <- array(val.y, dim = c(10, dim(val.x)[4]))
                                    val.y <- mx.nd.array(val.y)
                                    list(data=val.x, label=val.y)
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize=function(){
                                  }
                                )
)

my_iter3 <- my_iterator_func3(iter = NULL,  data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 20)

Using Fake Data to Assist Model Training (4)

dis_model <- mx.model.load('model/dis_v1', 0)

mx.set.seed(0)
new_arg <- mxnet:::mx.model.init.params(symbol = m_logloss,
                                        input.shape = list(data = c(28, 28, 1, 7), label = c(10, 7)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

for (k in 1:12) {
  new_arg$arg.params[[k]] <- dis_model$arg.params[[k]]
}

model.2 <- mx.model.FeedForward.create(symbol = m_logloss, X = my_iter3, optimizer = my_optimizer,
                                       eval.metric = my.eval.metric.loss,
                                       arg.params = new_arg$arg.params,
                                       array.batch.size = 20, ctx = mx.cpu(), num.round = 20)
model.2$symbol <- softmax

predict_Y <- predict(model.2, Test.X)
confusion_table <- table(max.col(t(predict_Y)), Test.Y)
cat("Testing accuracy rate =", sum(diag(confusion_table))/sum(confusion_table))
## Testing accuracy rate = 0.9817857
print(confusion_table)
##     Test.Y
##         0    1    2    3    4    5    6    7    8    9
##   1  1656    0    6    1    5    3   17    2   13   10
##   2     0 1840    5    3    5    1    0    5    3    1
##   3     1    3 1627    3    0    1    2    8    9    2
##   4     0    0   10 1705    0    4    1    3    3    3
##   5     0    0    0    0 1576    0    2    4    3   14
##   6     3    3    2   20    0 1534    7    0   12    8
##   7     0    1    0    0    4    6 1631    0   12    1
##   8     1    3    4    4    2    1    0 1725    1   17
##   9     0    1    2    3    2    0    1    2 1616    2
##   10    2    0    0    3   12    1    0    4    3 1584

– You might object that our generative model still has room for improvement, but this experiment broadly shows that fake data is of very limited help when training a supervised model.

Conclusion

– Moreover, if you want to exploit the unsupervised nature of GANs to assist the training of supervised models, you will find that so far the effect differs little from that of autoencoders, so GANs are not so easily put to practical use.

– Most current neural networks require an extremely large number of labeled samples to train well, yet human learning does not seem to need nearly as much labeled information. How to leverage unsupervised models to assist model training is therefore a major focus of current research.

[Figure F15_8]